package cloud.agileframework.data.common.auth;

import cloud.agileframework.common.constant.Constant;
import cloud.agileframework.common.util.json.JSONUtil;
import cloud.agileframework.common.util.template.VelocityUtil;
import cloud.agileframework.data.common.auth.annotation.AuthData;
import cloud.agileframework.spring.util.SecurityUtil;
import cloud.agileframework.sql.SqlUtil;
import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.filter.FilterEventAdapter;
import com.alibaba.druid.proxy.jdbc.CallableStatementProxy;
import com.alibaba.druid.proxy.jdbc.ConnectionProxy;
import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy;
import com.alibaba.druid.proxy.jdbc.ResultSetProxy;
import com.alibaba.druid.proxy.jdbc.StatementProxy;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.statement.SQLUnionQuery;
import com.alibaba.druid.sql.ast.statement.SQLUnionQueryTableSource;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.springframework.security.core.userdetails.UserDetails;

import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Druid {@link FilterEventAdapter} that rewrites outgoing SQL for data authorization.
 *
 * <p>When an {@link AuthData} configuration has been bound to the current thread through
 * {@link #setConfig(AuthData)}, every table reference in a SELECT whose name is a key of the
 * configured filter mapping is replaced by the mapped template. The template is rendered with
 * Velocity against the current user (plus the configured auth groups) and may be either a
 * sub-query or a plain table name. Call {@link #clear()} when the scope ends so the
 * configuration does not stay bound to the thread.
 *
 * <p>Rewriting is best-effort: if parsing or rendering fails, the original SQL is used unchanged.
 */
public class AuthFilter extends FilterEventAdapter {
    /** Table name -> replacement template, taken from {@link AuthDataProperties}. */
    private final Map<String, String> filterMapping;

    /** Authorization settings for SQL executed on the current thread; {@code null} disables rewriting. */
    private final ThreadLocal<AuthData> config = new ThreadLocal<>();

    /**
     * Binds the authorization settings used for SQL issued by the current thread.
     *
     * @param config the settings to apply until {@link #clear()} is called
     */
    public void setConfig(AuthData config) {
        this.config.set(config);
    }

    /** Removes the settings bound to the current thread, disabling rewriting for it. */
    public void clear() {
        this.config.remove();
    }

    /**
     * @param authDataProperties source of the table-name to replacement-template mapping
     */
    public AuthFilter(AuthDataProperties authDataProperties) {
        this.filterMapping = authDataProperties.getFilterMapping();
    }

    // ---- Plain statement execution: rewrite the SQL text, then delegate ----

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
        return super.statement_execute(chain, statement, parseSql(sql));
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys) throws SQLException {
        return super.statement_execute(chain, statement, parseSql(sql), autoGeneratedKeys);
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int[] columnIndexes) throws SQLException {
        return super.statement_execute(chain, statement, parseSql(sql), columnIndexes);
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, String[] columnNames) throws SQLException {
        return super.statement_execute(chain, statement, parseSql(sql), columnNames);
    }

    /**
     * Rewrites each SQL string queued with {@code addBatch(String)} before executing the batch.
     */
    @Override
    public int[] statement_executeBatch(FilterChain chain, StatementProxy statement) throws SQLException {
        // Prepared/callable statements were already rewritten when they were prepared, and their
        // batches hold parameter sets rather than SQL strings; clearing them here would discard
        // the queued parameters.
        if (!(statement instanceof PreparedStatementProxy)) {
            List<String> original = statement.getBatchSqlList();
            List<String> rewritten = original.stream().map(this::parseSql).collect(Collectors.toList());
            // Only rebuild the batch when at least one statement actually changed.
            if (!rewritten.equals(original)) {
                statement.clearBatch();
                for (String sql : rewritten) {
                    statement.addBatch(sql);
                }
            }
        }
        return super.statement_executeBatch(chain, statement);
    }

    @Override
    public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
        return super.statement_executeQuery(chain, statement, parseSql(sql));
    }

    // ---- Statement preparation: rewrite the SQL text before it is prepared ----

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql));
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql, int autoGeneratedKeys) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql), autoGeneratedKeys);
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql), resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql), resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql, int[] columnIndexes) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql), columnIndexes);
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql, String[] columnNames) throws SQLException {
        return super.connection_prepareStatement(chain, connection, parseSql(sql), columnNames);
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql) throws SQLException {
        return super.connection_prepareCall(chain, connection, parseSql(sql));
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return super.connection_prepareCall(chain, connection, parseSql(sql), resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return super.connection_prepareCall(chain, connection, parseSql(sql), resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    /**
     * Rewrites {@code sql} according to the {@link AuthData} bound to the current thread.
     *
     * @param sql the SQL about to be sent to the database
     * @return the rewritten SQL, or {@code sql} unchanged when no configuration is bound, it is
     *         disabled, no mapping is configured, no mapped table name occurs in the text, or
     *         rewriting fails
     */
    private String parseSql(String sql) {
        AuthData authData = config.get();
        if (authData == null
                || Boolean.FALSE.equals(authData.enable())
                || filterMapping == null
                || filterMapping.isEmpty()
                || sql == null
                // Cheap textual pre-check so statements that cannot match skip SQL parsing.
                || filterMapping.keySet().stream().noneMatch(sql::contains)) {
            return sql;
        }
        try {
            UserDetails userDetails = SecurityUtil.currentUser();

            SQLStatement sqlStatement = SQLUtils.parseSingleMysqlStatement(sql);
            // Only SELECT statements have their tables replaced; every statement is still
            // re-serialized and passed through SqlUtil.parserSQL with the current user.
            if (sqlStatement instanceof SQLSelectStatement) {
                parsing(((SQLSelectStatement) sqlStatement).getSelect(), authData);
            }

            return SqlUtil.parserSQL(SQLUtils.toSQLString(sqlStatement), userDetails);
        } catch (Exception e) {
            // Best-effort: on any parse/render failure the original SQL is executed unchanged.
            // NOTE(review): this fails open (no authorization rewrite is applied) — confirm intended.
            return sql;
        }
    }

    /**
     * Walks a query: the FROM clause of a simple query block, or every branch of a UNION.
     */
    private void parsing(SQLSelectQuery sqlSelectQuery, AuthData authData) {
        if (sqlSelectQuery instanceof SQLSelectQueryBlock) {
            parsing(((SQLSelectQueryBlock) sqlSelectQuery).getFrom(), authData);
        } else if (sqlSelectQuery instanceof SQLUnionQuery) {
            for (SQLSelectQuery child : ((SQLUnionQuery) sqlSelectQuery).getChildren()) {
                parsing(child, authData);
            }
        }
    }

    /**
     * Walks a SELECT. Uses the full query (not just its first query block) so that every branch
     * of a UNION is rewritten.
     */
    private void parsing(SQLSelect select, AuthData authData) {
        parsing(select.getQuery(), authData);
    }

    /**
     * Walks a table source, replacing mapped tables and recursing into joins, sub-queries and
     * UNION table sources.
     */
    private void parsing(SQLTableSource table, AuthData authData) {
        if (table instanceof SQLExprTableSource) {
            replaceTable((SQLExprTableSource) table, authData);
        } else if (table instanceof SQLJoinTableSource) {
            parsing(((SQLJoinTableSource) table).getLeft(), authData);
            parsing(((SQLJoinTableSource) table).getRight(), authData);
        } else if (table instanceof SQLSubqueryTableSource) {
            parsing(((SQLSubqueryTableSource) table).getSelect(), authData);
        } else if (table instanceof SQLUnionQueryTableSource) {
            SQLUnionQuery union = ((SQLUnionQueryTableSource) table).getUnion();
            for (SQLSelectQuery sqlSelectQuery : union.getChildren()) {
                parsing(sqlSelectQuery, authData);
            }
        }
    }

    /**
     * Replaces {@code table} in its parent with the rendered mapping template, if its name is mapped.
     * A template that parses as a query becomes a sub-query; one that parses as a plain
     * identifier becomes a renamed table. Any other expression form is left untouched.
     */
    private void replaceTable(SQLExprTableSource table, AuthData authData) {
        String tableName = table.getTableName();
        if (tableName == null) {
            return;
        }
        String tableProxy = filterMapping.get(tableName);
        if (StringUtils.isBlank(tableProxy)) {
            return;
        }

        // Template parameters: the current user as JSON, plus the configured auth groups.
        UserDetails userDetails = SecurityUtil.currentUser();
        JSONObject jsonObject = (JSONObject) JSONUtil.toJSON(userDetails);
        if (authData.group().length > 0) {
            jsonObject.put(Constant.AgileAbout.AUTH_GROUP, authData.group());
        }
        tableProxy = VelocityUtil.parse(tableProxy, jsonObject);

        // Keep the caller's alias; otherwise reuse the table name so that columns qualified by
        // the table name still resolve and two unaliased replaced tables do not share an alias.
        String alias = table.getAlias();
        if (alias == null) {
            alias = tableName;
        }

        SQLExpr replacement = SQLUtils.toSQLExpr(tableProxy);
        if (replacement instanceof SQLQueryExpr) {
            SQLUtils.replaceInParent(table, new SQLSubqueryTableSource(((SQLQueryExpr) replacement).getSubQuery(), alias));
        } else if (replacement instanceof SQLIdentifierExpr) {
            SQLUtils.replaceInParent(table, new SQLExprTableSource(replacement, alias));
        }
    }
}
